{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Character based RNN language model\n",
    "(c) Deniz Yuret, 2019. Based on http://karpathy.github.io/2015/05/21/rnn-effectiveness."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Objectives: Learn to define and train a character based language model and generate text from it. Minibatch blocks of text. Keep a persistent RNN state between updates. Train a Shakespeare generator and a Julia programmer using the same type of model.\n",
    "* Prerequisites: [RNN basics](60.rnn.ipynb), [Iterators](25.iterators.ipynb)\n",
    "* New functions:\n",
    "[converge](http://denizyuret.github.io/Knet.jl/latest/reference/#Knet.converge)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set display width, load packages, import symbols\n",
    "ENV[\"COLUMNS\"]=72\n",
    "using Pkg; haskey(Pkg.installed(),\"Knet\") || Pkg.add(\"Knet\")\n",
    "using Statistics: mean\n",
    "using Base.Iterators: cycle\n",
    "using IterTools: takenth\n",
    "using Knet: Knet, AutoGrad, Data, param, param0, mat, RNN, dropout, value, nll, adam, minibatch, progress!, converge"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding layer: `w` stores one column per vocabulary id.\n",
    "struct Embed\n",
    "    w\n",
    "end\n",
    "\n",
    "# Create an Embed whose weights come from `param(embed, vocab)`.\n",
    "function Embed(vocab::Int, embed::Int)\n",
    "    return Embed(param(embed, vocab))\n",
    "end\n",
    "\n",
    "# Select the columns of `w` indexed by the ids in `x`.\n",
    "# Shapes: (B,T)->(X,B,T)->rnn->(H,B,T)\n",
    "function (layer::Embed)(x)\n",
    "    return layer.w[:, x]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Affine output layer holding a weight matrix `w` and a bias `b`.\n",
    "struct Linear\n",
    "    w\n",
    "    b\n",
    "end\n",
    "\n",
    "# Weights come from `param(output, input)`, the bias from `param0(output)`.\n",
    "function Linear(input::Int, output::Int)\n",
    "    return Linear(param(output, input), param0(output))\n",
    "end\n",
    "\n",
    "# Reshape `x` to 2-D with `mat(x, dims=1)`, then compute w*x .+ b.\n",
    "# Shapes: (H,B,T)->(H,B*T)->(V,B*T)\n",
    "function (layer::Linear)(x)\n",
    "    flat = mat(x, dims=1)\n",
    "    return layer.w * flat .+ layer.b\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A chain of layers, applied one after another\n",
    "struct Chain\n",
    "    layers\n",
    "    function Chain(layers...)\n",
    "        return new(layers)   # the arguments are kept as a tuple\n",
    "    end\n",
    "end\n",
    "\n",
    "# Forward pass: feed `x` through every layer in order.\n",
    "function (c::Chain)(x)\n",
    "    out = x\n",
    "    for layer in c.layers\n",
    "        out = layer(out)\n",
    "    end\n",
    "    return out\n",
    "end\n",
    "\n",
    "# Loss for one (input, target) pair: nll of the predictions c(x) against y.\n",
    "function (c::Chain)(x, y)\n",
    "    return nll(c(x), y)\n",
    "end\n",
    "\n",
    "# Loss for a dataset: mean of the per-minibatch losses.\n",
    "function (c::Chain)(d::Data)\n",
    "    return mean(c(x, y) for (x, y) in d)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CharLM (generic function with 1 method)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Character language model: Embed -> RNN -> Linear.\n",
    "# The h=0,c=0 options to RNN enable a persistent state between iterations;\n",
    "# any extra keyword arguments in `o` are forwarded to RNN.\n",
    "function CharLM(vocab::Int, embed::Int, hidden::Int; o...)\n",
    "    embedding = Embed(vocab, embed)\n",
    "    recurrent = RNN(embed, hidden; h=0, c=0, o...)\n",
    "    output = Linear(hidden, vocab)\n",
    "    return Chain(embedding, recurrent, output)\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and test utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "trainresults (generic function with 1 method)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For running experiments.\n",
    "# Either trains a fresh model or loads a saved one, and returns (model, chars).\n",
    "#   file  : where the trained model is saved, or the saved model to load (downloaded if missing)\n",
    "#   maker : zero-argument function that builds a new model\n",
    "#   chars : character table stored alongside the model\n",
    "# Reads the global minibatch iterators `dtrn` and `dtst`.\n",
    "function trainresults(file,maker,chars)\n",
    "    print(\"Train from scratch? \")\n",
    "    # startswith copes with an empty answer (plain Enter or EOF); indexing [1] would throw there\n",
    "    if startswith(readline(), \"y\")\n",
    "        model = maker()\n",
    "        a = adam(model,cycle(dtrn))                      # adam updates over endlessly repeated training data\n",
    "        b = (exp(model(dtst)) for _ in takenth(a,100))   # test perplexity at every 100th update\n",
    "        c = converge(b, alpha=0.1)                       # wraps b with converge's stopping rule\n",
    "        progress!(p->p.currval, c)                       # run the training, reporting the current value\n",
    "        Knet.save(file,\"model\",model,\"chars\",chars)\n",
    "    else\n",
    "        isfile(file) || download(\"http://people.csail.mit.edu/deniz/models/tutorial/$file\",file)\n",
    "        model,chars = Knet.load(file,\"model\",\"chars\")\n",
    "    end\n",
    "    Knet.gc() # To save gpu memory\n",
    "    return model,chars\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# To generate text from trained models.\n",
    "# Prints `n` characters sampled from `model`; `chars` maps integer ids to characters.\n",
    "function generate(model,chars,n)\n",
    "    # Draw an index with probability proportional to exp of the scores in `y`.\n",
    "    function sample(y)\n",
    "        s = Array(y)\n",
    "        p = exp.(s .- maximum(s))   # shifting by the max keeps exp from overflowing to Inf\n",
    "        r = rand()*sum(p)\n",
    "        for j=1:length(p); (r -= p[j]) < 0 && return j; end\n",
    "        return length(p)            # rounding can leave r >= 0 after the loop; never return nothing\n",
    "    end\n",
    "    x = 1                           # generation starts from character id 1\n",
    "    reset!(model)                   # zero the RNN state before generating\n",
    "    for i=1:n\n",
    "        y = model([x])\n",
    "        x = sample(y)\n",
    "        print(chars[x])\n",
    "    end\n",
    "    println()\n",
    "end\n",
    "\n",
    "# Set the h and c state of every RNN layer in the chain back to 0.\n",
    "reset!(m::Chain)=(for r in m.layers; r isa RNN && (r.c=r.h=0); end);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Complete Works of William Shakespeare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "RNNTYPE = :lstm      # forwarded to RNN as rnnType\n",
    "BATCHSIZE = 256      # rows (B) of each minibatch built by mb\n",
    "SEQLENGTH = 100      # columns (T) of each minibatch built by mb\n",
    "VOCABSIZE = 84       # vocabulary size given to CharLM; shakechars has 84 entries\n",
    "INPUTSIZE = 168      # embedding size given to CharLM\n",
    "HIDDENSIZE = 334     # RNN hidden size given to CharLM\n",
    "# NUMLAYERS is forwarded to RNN as numLayers\n",
    "NUMLAYERS = 1;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(\"4934845-element Array{UInt8,1}\", \"526731-element Array{UInt8,1}\", \"84-element Array{Char,1}\")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load 'The Complete Works of William Shakespeare' with the\n",
    "# gutenberg.jl script found under Knet's data directory.\n",
    "include(Knet.dir(\"data\", \"gutenberg.jl\"))\n",
    "(trn, tst, shakechars) = shakespeare()   # train ids, test ids, id -> Char table\n",
    "summary.((trn, tst, shakechars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "    Cheated of feature by dissembling nature,\r\n",
      "    Deform'd, unfinish'd, sent before my time\r\n",
      "    Into this breathing world scarce half made up,\r\n",
      "    And that so lamely and unfashionable\r\n",
      " \n"
     ]
    }
   ],
   "source": [
    "# Print a sample: decode a slice of the training ids back into text\n",
    "println(join(shakechars[trn[1020:1210]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(192, 20)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Minibatch data\n",
    "# Cuts `a` into BATCHSIZE contiguous streams and returns the (x,y) minibatches,\n",
    "# where y is x shifted one position ahead. Reads the globals BATCHSIZE and SEQLENGTH.\n",
    "function mb(a)\n",
    "    ncols = length(a) ÷ BATCHSIZE\n",
    "    streams = permutedims(reshape(a[1:ncols*BATCHSIZE], ncols, BATCHSIZE)) # (B,N), each row a contiguous piece of a\n",
    "    inputs = streams[:, 1:ncols-1]\n",
    "    targets = streams[:, 2:ncols]   # the next element for every input position\n",
    "    return minibatch(inputs, targets, SEQLENGTH) # split into (B,T) blocks\n",
    "end\n",
    "(dtrn, dtst) = map(mb, (trn, tst))\n",
    "map(length, (dtrn, dtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(\"256×100 Array{UInt8,2}\", \"256×100 Array{UInt8,2}\")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The first minibatch is an (x,y) pair; each x and y have dimensions (BATCHSIZE,SEQLENGTH)\n",
    "map(summary, first(dtrn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train from scratch? stdin> n\n"
     ]
    }
   ],
   "source": [
    "# [180, 06:58, 2.32s/i] 3.3026385  (presumably the progress output of an earlier training run; TODO confirm)\n",
    "# Builds a new Shakespeare model from the hyperparameters defined above.\n",
    "function shakemaker()\n",
    "    return CharLM(VOCABSIZE, INPUTSIZE, HIDDENSIZE; rnnType=RNNTYPE, numLayers=NUMLAYERS)\n",
    "end\n",
    "(shakemodel, shakechars) = trainresults(\"shakespeare132.jld2\", shakemaker, shakechars);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to compute the test-set perplexity (value noted by the author: 3.3150165f0)\n",
    "#exp(shakemodel(dtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PARYNENGET           With her good wife\n",
      "\n",
      "                Enter KING PHILJES; NECNOLBUS\n",
      "  PAGE OVERNAS, the a darkey sounds in colours.\n",
      "    Petruchio's gart, sweet Jamor, alike how sliolat\n",
      "     Up her shrine, and all these will not be not tall-\n",
      "    But if thou chost of thoo I will never hang\n",
      "    To fright a follier wrongs with drum.  \n",
      "  Ophin. It is worth else.\n",
      "  Hot. A very sign is a deed-proud. How do't thy little fell!  \n",
      "  Lear. I am lost ourselves, and thou begst it from Some for the day\n",
      "     place in last, that diverses thy prince women\n",
      "     That it is elft his beard; the rue too late\n",
      "     Not truck at me. It give them for quires.\n",
      "  IAGO. I yet rough this if his heart of eternal women\n",
      "    I should accuse your voice.\n",
      "  Lear. Yet, sir, a tear.\n",
      "    And stock, my lord.\n",
      "  Friar. Farewell!\n",
      "  Beat. What a base devil sparefull some gentry should have will not talk with the\n",
      "    cap\n",
      "    unseas'd; that the Hotsoman knew'st it now, come!--Let's buring one\n",
      "    reason to ho\n"
     ]
    }
   ],
   "source": [
    "# Sample and print 1000 characters from the Shakespeare model\n",
    "generate(shakemodel,shakechars,1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Julia programmer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These reassign the globals used by the Shakespeare section above.\n",
    "RNNTYPE = :lstm      # forwarded to RNN as rnnType\n",
    "BATCHSIZE = 64       # rows (B) of each minibatch built by mb\n",
    "SEQLENGTH = 64       # columns (T) of each minibatch built by mb\n",
    "INPUTSIZE = 512      # embedding size given to CharLM\n",
    "VOCABSIZE = 128      # vocabulary size given to CharLM; larger char ids are mapped to this value when the text is encoded\n",
    "HIDDENSIZE = 512     # RNN hidden size given to CharLM\n",
    "# NUMLAYERS is forwarded to RNN as numLayers\n",
    "NUMLAYERS = 2;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2832949"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read julia base library source code: concatenate every .jl file found under `base`\n",
    "base = joinpath(Sys.BINDIR, Base.DATAROOTDIR, \"julia\", \"base\")\n",
    "# sprint writes everything into one buffer; `text *= ...` in a loop recopies the whole string for every file\n",
    "text = sprint() do io\n",
    "    for (root,dirs,files) in walkdir(base)\n",
    "        for f in files\n",
    "            # endswith also handles names shorter than 3 characters or with non-ASCII characters,\n",
    "            # where f[end-2:end] throws\n",
    "            endswith(f, \".jl\") || continue\n",
    "            write(io, read(joinpath(root,f), String))\n",
    "        end\n",
    "    end\n",
    "end\n",
    "length(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "197×2 Array{Any,2}:\n",
       " ' '   593337\n",
       " 'e'   183265\n",
       " 't'   146524\n",
       " 'n'   128856\n",
       " 'i'   119906\n",
       " 'r'   109778\n",
       " 'a'   103761\n",
       " 's'    99377\n",
       " 'o'    94416\n",
       " '\\n'   89082\n",
       " 'l'    70424\n",
       " 'd'    61536\n",
       " 'c'    57355\n",
       " ⋮           \n",
       " 'μ'        1\n",
       " '”'        1\n",
       " '≶'        1\n",
       " 'ᶜ'        1\n",
       " '¾'        1\n",
       " '⁹'        1\n",
       " '⛵'        1\n",
       " '⁷'        1\n",
       " '⊗'        1\n",
       " 'Λ'        1\n",
       " '⋅'        1\n",
       " 'ϒ'        1"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Find unique chars, sort by frequency, assign integer ids.\n",
    "charcnt = Dict{Char,Int}()\n",
    "for ch in text\n",
    "    charcnt[ch] = get(charcnt, ch, 0) + 1\n",
    "end\n",
    "juliachars = sort(collect(keys(charcnt)), by=(x->charcnt[x]), rev=true)   # most frequent first\n",
    "charid = Dict{Char,Int}()\n",
    "for (id, ch) in enumerate(juliachars)\n",
    "    charid[ch] = id   # id 1 is the most frequent char\n",
    "end\n",
    "hcat(juliachars, [charcnt[ch] for ch in juliachars])   # two-column table: char, count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2832949, 2308661, 524288)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encode the text as integer ids. Ids above VOCABSIZE are replaced by VOCABSIZE,\n",
    "# so all the rarer chars share that single id. Then split into test and train.\n",
    "data = min.([charid[ch] for ch in text], VOCABSIZE)\n",
    "ntst = 2^19\n",
    "tst = data[1:ntst]         # the first ntst ids are the test set\n",
    "trn = data[ntst+1:end]     # the rest is the training set\n",
    "map(length, (data, trn, tst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For `n <\n",
      "0`, this is equivalent to `x << -n`.\n",
      "\n",
      "# Examples\n",
      "```jldoctest\n",
      "julia> Int8(13) >> 2\n",
      "3\n",
      "\n",
      "julia> bitstring(Int8(13))\n",
      "\"00001101\"\n",
      "\n",
      "julia> bitstring(Int8(3))\n",
      "\"00000011\"\n",
      "\n",
      "julia> Int8(-14) >> 2\n",
      "-4\n",
      "\n",
      "julia> bitstring(Int8(-14))\n",
      "\"11110010\"\n",
      "\n",
      "julia> bitstring(Int8(-4))\n",
      "\"11111100\"\n",
      "```\n",
      "See also [`>>>`](@ref), [`<<`](@ref).\n",
      "\"\"\"\n",
      "function >>(x::Integer, c::Integer)\n",
      "    @_inline_meta\n",
      "    if c isa UInt\n",
      "        throw(MethodError(>>, (x, c)))\n",
      "    end\n",
      "    typemin(Int) <= c <= typemax(Int) && return x >> (c % Int)\n",
      "    (x >= 0 || c < 0) && return zero(x) >> 0\n",
      "    oftype(x, -1)\n",
      "end\n",
      ">>(x::Integer, c::Int) = c >= 0 ? x >> unsigned(c) : x << unsigned(-c)\n",
      "\n",
      "\"\"\"\n",
      "    >>>(x, n)\n",
      "\n",
      "Unsigned right bit shift operator, `x >>> n`. For `n >= 0`, the result is `x`\n",
      "shifted right by `n` bits, where `n >= 0`, filling with `0`s. For `n < 0`, this\n",
      "is equivalent to `x << -n`.\n",
      "\n",
      "For [`Unsigned`](@ref) integer types, this is equivalent to [`>>`](@ref). For\n",
      "[`Signed`](@ref) integer types, this is equivalent to `signed(unsigned(x) \n"
     ]
    }
   ],
   "source": [
    "# Print a sample: 1001 consecutive training characters from a random start position\n",
    "r = rand(1:(length(trn)-1000))\n",
    "println(join(juliachars[trn[r:r+1000]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(563, 127)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Minibatch data\n",
    "# Identical to the `mb` defined in the Shakespeare section. BATCHSIZE and SEQLENGTH are\n",
    "# globals read at call time, so the values set for this section apply.\n",
    "# Cuts `a` into BATCHSIZE contiguous streams and returns the (x,y) minibatches,\n",
    "# where y is x shifted one position ahead.\n",
    "function mb(a)\n",
    "    ncols = length(a) ÷ BATCHSIZE\n",
    "    streams = permutedims(reshape(a[1:ncols*BATCHSIZE], ncols, BATCHSIZE)) # (B,N), each row a contiguous piece of a\n",
    "    inputs = streams[:, 1:ncols-1]\n",
    "    targets = streams[:, 2:ncols]   # the next element for every input position\n",
    "    return minibatch(inputs, targets, SEQLENGTH) # split into (B,T) blocks\n",
    "end\n",
    "(dtrn, dtst) = map(mb, (trn, tst))\n",
    "map(length, (dtrn, dtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(\"64×64 Array{Int64,2}\", \"64×64 Array{Int64,2}\")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The first minibatch is an (x,y) pair; each x and y have dimensions (BATCHSIZE,SEQLENGTH)\n",
    "map(summary, first(dtrn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train from scratch? stdin> n\n"
     ]
    }
   ],
   "source": [
    "# [150, 05:04, 2.03s/i] 3.2988634  (presumably the progress output of an earlier training run; TODO confirm)\n",
    "# Builds a new Julia-source model from the hyperparameters defined above.\n",
    "function juliamaker()\n",
    "    return CharLM(VOCABSIZE, INPUTSIZE, HIDDENSIZE; rnnType=RNNTYPE, numLayers=NUMLAYERS)\n",
    "end\n",
    "(juliamodel, juliachars) = trainresults(\"juliacharlm132.jld2\", juliamaker, juliachars);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to compute the test-set perplexity (value noted by the author: 3.8615866f0)\n",
    "#exp(juliamodel(dtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x             t)\n",
      "        @inline occursin = fillpadd(min(one(Tr,UInt64(1))) + 1*n - 4Valued_byte\n",
      "        l += 1\n",
      "        push!(based_end(x,cptr))\n",
      "        m = read(f, Int32(0))\n",
      "        if padding == 2\n",
      "            convert(P, dump(xyp, xmpfilebuf)) + leading_zeros(i2, len)\n",
      "            push!(blk.args, :(print(out, 'x'), 'x'))\n",
      "            elseif pointext = get(io,:SignificandSoze(b)->exponent_biass[], pb, 0)\n",
      "            cad = '5' ? fixupdic(dept)\n",
      "            exponent_base = 32 && (p == 17) || buf == 1\n",
      "            hhigh = UInt32(b)\n",
      "            t = BigFloat(width, p)\n",
      "            higherror(a)\n",
      "        end\n",
      "    elseif arrayref === Float32(heanonizerox(aft))\n",
      "        return Base.extread_block\n",
      "    else\n",
      "        len = 0 # final errors out might\n",
      "        fmp = dispatchunk(domtree, self)\n",
      "        x = point\n",
      "    end\n",
      "    blk = ccall(:GetStypes, pad, Any, (Csize_t, Base), s, false, 0)\n",
      "    while !isempty(Woppend[keys(prevind(W, first_byteslice(str))))>0.0, 0,\n",
      "                blust_idx > block.succs[2]]\n",
      "        \n"
     ]
    }
   ],
   "source": [
    "# Sample and print 1000 characters from the Julia-source model\n",
    "generate(juliamodel,juliachars,1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "julia.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Julia 1.3.1",
   "language": "julia",
   "name": "julia-1.3"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.3.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
